{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Migrate torchtext from the legacy API to the new API",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/zhangguanheng66/text/blob/migration_tutorial/examples/legacy_tutorial/migration_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HMGWxQCO7s0e"
      },
      "source": [
        "!pip install -U torch==1.8.0 torchtext==0.9.0\n",
        "\n",
        "# Reload environment\n",
        "exit()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jXUgsnxw70-M"
      },
      "source": [
        "This is a tutorial to show how to migrate from the legacy API in torchtext to the new API in 0.9.0 release. Here, we take the IMDB dataset as an example for the sentiment analysis. Both legacy and new APIs in torchtext can preprocess the text input and prepare the data to train/validate a model with the following steps:\n",
        "\n",
        "*   Train/validate/test split: generate train/validate/test data set if they are available\n",
        "*   Tokenization: break a raw text string sentence into a list of words\n",
        "*   Vocab: define a \"contract\" from tokens to indexes\n",
        "*   Numericalize: convert a list of tokens to the corresponding indexes\n",
        "*   Batch: generate batches of data samples and add padding if necessary\n",
        "\n",
        "It should be noted that all the legacy features are still available, but within torchtext.legacy instead of torchtext."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WWRW4bsL8UL0"
      },
      "source": [
        "## Step 1: Create a dataset object\n",
        "----------------------------\n",
        "\n",
        "Fist of all, we create a dataset for the sentiment analysis. The individual data sample contains a label and a text string.\n",
        "\n",
        "### *Legacy*\n",
        "In the legacy code, `Field` class is used for data processing, including tokenizer and numberzation. To check out the dataset, users need to first set up the TEXT/LABEL fields."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FttPxcbc70j1"
      },
      "source": [
        "import torchtext\n",
        "import torch\n",
        "from torchtext.legacy import data\n",
        "from torchtext.legacy import datasets\n",
        "\n",
        "TEXT = data.Field()\n",
        "LABEL = data.LabelField(dtype = torch.long)\n",
        "legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)  # datasets here refers to torchtext.legacy.datasets"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssXfxJJSq7WT"
      },
      "source": [
        "You can print out the raw data by checking out Dataset.examples. The entire text data are stored as a list of tokens."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7DRXJFgzriaH"
      },
      "source": [
        "legacy_examples = legacy_train.examples\n",
        "print(legacy_examples[0].text, legacy_examples[0].label)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQMfAN_Fz3aa"
      },
      "source": [
        "### *New*\n",
        "The new dataset API returns the train/test dataset split directly without the preprocessing information. Each split is an iterator which yields the raw texts and labels line-by-line."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YHUYZ7yt0Lb5"
      },
      "source": [
        "from torchtext.datasets import IMDB\n",
        "train_iter, test_iter = IMDB(split=('train', 'test'))"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yB7MShEBsd3P"
      },
      "source": [
        "To print out the raw data, you can call the next() function on the IterableDataset."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wUkWE1KWsPqy"
      },
      "source": [
        "next(train_iter)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ycL7xqRP0eLU"
      },
      "source": [
        "## Step 2 Build the data processing pipeline\n",
        "----------------------------\n",
        "\n",
        "### *Legacy*\n",
        "\n",
        "The default tokenizer implemented in the `Field` class is the built-in python `split()` function. Users choose the tokenizer by calling `data.get_tokenizer()`, and add it to the `Field` constructor. For the sequence model, it's common to append `<BOS>` (begin-of-sentence) and `<EOS>` (end-of-sentence) tokens, and the special tokens need to be defined in the `Field` class."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8H_I_XW8gSR1"
      },
      "source": [
        "TEXT = data.Field(tokenize=data.get_tokenizer('basic_english'),\n",
        "                  init_token='<SOS>', eos_token='<EOS>', lower=True)\n",
        "LABEL = data.LabelField(dtype = torch.long)\n",
        "legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)  # datasets here refers to torchtext.legacy.datasets"
      ],
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "opQ6LcnigTKx"
      },
      "source": [
        "Now you can create a vocabulary of the words from the text file stored in the predefined `Field` object, `TEXT`. You fist have to build a vocabulary in your `Field` object by passing the dataset to the `build_vocab` func. The Field object builds the vocabulary (`TEXT.vocab`) on a specific data split."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cffl6ueN8T5X"
      },
      "source": [
        "TEXT.build_vocab(legacy_train)\n",
        "LABEL.build_vocab(legacy_train)"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OXQ9rmiHt58H"
      },
      "source": [
        "Things you can do with a vocabuary object\n",
        "\n",
        "\n",
        "*   Total length of the vocabulary\n",
        "*   String2Index (stoi) and Index2String (itos)\n",
        "*   A purpose-specific vocabulary which contains word appearing more than N times\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YzweKLh5uSNC"
      },
      "source": [
        "legacy_vocab = TEXT.vocab\n",
        "print(\"The length of the legacy vocab is\", len(legacy_vocab))\n",
        "legacy_stoi = legacy_vocab.stoi\n",
        "print(\"The index of 'example' is\", legacy_stoi['example'])\n",
        "legacy_itos = legacy_vocab.itos\n",
        "print(\"The token at index 686 is\", legacy_itos[686])\n",
        "\n",
        "# Set up the mim_freq value in the Vocab class\n",
        "TEXT.build_vocab(legacy_train, min_freq=10)\n",
        "legacy_vocab2 = TEXT.vocab\n",
        "print(\"The length of the legacy vocab is\", len(legacy_vocab2))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LXTibHc00olW"
      },
      "source": [
        "### *New*\n",
        "\n",
        "Users have the access to different kinds of tokenizers directly via `data.get_tokenizer()` function."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QavK23zjhNlx"
      },
      "source": [
        "from torchtext.data.utils import get_tokenizer\n",
        "tokenizer = get_tokenizer('basic_english')"
      ],
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TnNpf4mWF5pe"
      },
      "source": [
        "To have more flexibility, users can build the vocabulary directly with the Vocab class. For example, the argument `min_freq` is to set up the cutoff frequency to in the vocabulary. The special tokens, like `<BOS>` and `<EOS>` can be assigned to the special symbols in the constructor of the Vocab class."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ro8HXPwmwtp7"
      },
      "source": [
        "from collections import Counter\n",
        "from torchtext.vocab import Vocab\n",
        "\n",
        "train_iter = IMDB(split='train')\n",
        "counter = Counter()\n",
        "for (label, line) in train_iter:\n",
        "    counter.update(tokenizer(line))\n",
        "vocab = Vocab(counter, min_freq=10, specials=('<unk>', '<BOS>', '<EOS>', '<PAD>'))"
      ],
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xGuqqa7CxLq8"
      },
      "source": [
        "print(\"The length of the new vocab is\", len(vocab))\n",
        "new_stoi = vocab.stoi\n",
        "print(\"The index of '<BOS>' is\", new_stoi['<BOS>'])\n",
        "new_itos = vocab.itos\n",
        "print(\"The token at index 2 is\", new_itos[2])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l31FBekVr9j8"
      },
      "source": [
        "Both `text_transform` and `label_transform` are the callable object, such as a lambda func here, to process the raw text and label data from the dataset iterators. Users can add the special symbols `<BOS>` and `<EOS>` to the sentence in `text_transform`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ez2lT2QO0sNj"
      },
      "source": [
        "text_transform = lambda x: [vocab['<BOS>']] + [vocab[token] for token in tokenizer(x)] + [vocab['<EOS>']]\n",
        "label_transform = lambda x: 1 if x == 'pos' else 0\n",
        "\n",
        "# Print out the output of text_transform\n",
        "print(\"input to the text_transform:\", \"here is an example\")\n",
        "print(\"output of the text_transform:\", text_transform(\"here is an example\"))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4dEG7pyi1ElM"
      },
      "source": [
        "## Step 3: Generate batch iterator\n",
        "--------------------------------\n",
        "\n",
        "To train a model efficiently, it's recommended to build an iterator to generate data batch.\n",
        "\n",
        "### *Legacy*\n",
        "The legacy `Iterator` class is used to batch the dataset and send to the target device, like CPU or GPU."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NN67ofUB-sz1"
      },
      "source": [
        "import torch\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)  # datasets here refers to torchtext.legacy.datasets\n",
        "legacy_train_iterator, legacy_test_iterator = data.Iterator.splits(\n",
        "    (legacy_train, legacy_test), batch_size=8, device = device)"
      ],
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vBMjFVvsMPqR"
      },
      "source": [
        "For a NLP workflow, it's also common to define an iterator and batch texts with similar lengths together. The legacy `BucketIterator` class in torchtext library minimizes the amount of padding needed."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PgC6dhDqMOjp"
      },
      "source": [
        "from torchtext.legacy.data import BucketIterator\n",
        "legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)\n",
        "legacy_train_bucketiterator, legacy_test_bucketiterator = data.BucketIterator.splits(\n",
        "    (legacy_train, legacy_test),\n",
        "    sort_key=lambda x: len(x.text),\n",
        "    batch_size=8, device = device)"
      ],
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kBV-Wvlo07ye"
      },
      "source": [
        "### *New*\n",
        "\n",
        "`torch.utils.data.DataLoader` is used to generate data batch. Users could customize the data batch by defining a function with the `collate_fn` argument in the DataLoader. Here, in the `collate_batch` func, we process the raw text data and add padding to dynamically match the longest sentence in a batch."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EC054Wlr0-xB"
      },
      "source": [
        "from torch.utils.data import DataLoader\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        "\n",
        "def collate_batch(batch):\n",
        "   label_list, text_list = [], []\n",
        "   for (_label, _text) in batch:\n",
        "        label_list.append(label_transform(_label))\n",
        "        processed_text = torch.tensor(text_transform(_text))\n",
        "        text_list.append(processed_text)\n",
        "   return torch.tensor(label_list), pad_sequence(text_list, padding_value=3.0)\n",
        "\n",
        "train_iter = IMDB(split='train')\n",
        "train_dataloader = DataLoader(list(train_iter), batch_size=8, shuffle=True, \n",
        "                              collate_fn=collate_batch)"
      ],
      "execution_count": 20,
      "outputs": []
    },
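    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal, self-contained sketch of the padding step, the cell below runs `pad_sequence` on two toy index tensors. The index values are made up for illustration and do not come from IMDB, and `PAD_IDX` is a hypothetical name; we assume index 3 stands for `<PAD>`, as in the vocabulary built above. `pad_sequence` stacks variable-length tensors into one tensor of shape `(max_len, batch_size)` and fills the shorter columns with `padding_value`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import torch\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        "\n",
        "PAD_IDX = 3  # assumed index of '<PAD>', matching the specials order used above\n",
        "\n",
        "# Hypothetical numericalized sentences of different lengths\n",
        "toy_batch = [torch.tensor([1, 10, 11, 2]), torch.tensor([1, 12, 2])]\n",
        "padded = pad_sequence(toy_batch, padding_value=PAD_IDX)\n",
        "\n",
        "print(padded.shape)  # (longest sequence, batch size)\n",
        "print(padded)        # the shorter sentence ends with PAD_IDX"
      ],
      "execution_count": null,
      "outputs": []
    },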
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jky4X-iFU4HK"
      },
      "source": [
        "To group the texts with similar length together, like introduced in the legacy `BucketIterator` class, first of all, we randomly create multiple \"pools\", and each of them has a size of `batch_size * 100`. Then, we sort the samples within the individual pool by length. This idea can be implemented succintly through `batch_sampler` argument of PyTorch `Dataloader`. `batch_sampler` accepts 'Sampler' or Iterable object that yields indices of next batch. In the code below, we implemented a generator that yields batch of indices for which the corresponding batch of data is of similar length. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zCvxeLbYW3I_"
      },
      "source": [
        "import random\n",
        "\n",
        "train_iter = IMDB(split='train')\n",
        "train_list = list(train_iter)\n",
        "batch_size = 8  # A batch size of 8\n",
        "\n",
        "def batch_sampler():\n",
        "    indices = [(i, len(tokenizer(s[1]))) for i, s in enumerate(train_list)]\n",
        "    random.shuffle(indices)\n",
        "    pooled_indices = []\n",
        "    # create pool of indices with similar lengths \n",
        "    for i in range(0, len(indices), batch_size * 100):\n",
        "        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))\n",
        "\n",
        "    pooled_indices = [x[0] for x in pooled_indices]\n",
        "\n",
        "    # yield indices for current batch\n",
        "    for i in range(0, len(pooled_indices), batch_size):\n",
        "        yield pooled_indices[i:i + batch_size]\n",
        "\n",
        "bucket_dataloader = DataLoader(train_list, batch_sampler=batch_sampler(),\n",
        "                               collate_fn=collate_batch)\n",
        "\n",
        "print(next(iter(bucket_dataloader)))"
      ],
      "execution_count": 24,
      "outputs": []
    },
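    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The pooling logic can also be checked in isolation, without downloading IMDB. The sketch below uses a hypothetical helper, `toy_batch_sampler`, that applies the same shuffle, pool, sort and slice steps to a list of made-up sequence lengths. The helper name, the `pool_factor` parameter and the small sizes are assumptions chosen to keep the output readable. Each printed line shows a batch of indices next to the lengths of the samples they point to."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import random\n",
        "\n",
        "def toy_batch_sampler(lengths, batch_size, pool_factor):\n",
        "    # Pair each sample index with its length, then shuffle\n",
        "    indices = list(enumerate(lengths))\n",
        "    random.shuffle(indices)\n",
        "    pool_size = batch_size * pool_factor\n",
        "    pooled = []\n",
        "    # Sort by length inside each pool only, so the batch order stays partly random\n",
        "    for i in range(0, len(indices), pool_size):\n",
        "        pooled.extend(sorted(indices[i:i + pool_size], key=lambda x: x[1]))\n",
        "    # Slice the pooled order into batches of sample indices\n",
        "    for i in range(0, len(pooled), batch_size):\n",
        "        yield [idx for idx, _ in pooled[i:i + batch_size]]\n",
        "\n",
        "random.seed(0)\n",
        "toy_lengths = [random.randint(5, 50) for _ in range(16)]  # made-up lengths\n",
        "for batch in toy_batch_sampler(toy_lengths, batch_size=2, pool_factor=4):\n",
        "    print(batch, [toy_lengths[i] for i in batch])"
      ],
      "execution_count": null,
      "outputs": []
    },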
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0wrbC_v01Ib9"
      },
      "source": [
        "## Step 4: Iterate batch to train a model\n",
        "-------------------------------\n",
        "\n",
        "It's almost same for both legacy and new APIs to iterate the data for batches during training and validating a model.\n",
        "\n",
        "### *Legacy*\n",
        "\n",
        "The legacy batch iterator can be iterated or executed with `next()` method."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X_tml54u-6AS"
      },
      "source": [
        "# for item in legacy_train_iterator:\n",
        "#   model(item)\n",
        "\n",
        "# Or\n",
        "next(iter(legacy_train_iterator))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sRTvfxMB1P2P"
      },
      "source": [
        "### *New*\n",
        "\n",
        "The batch iterator can be iterated or executed with `next()` method."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iTotRtXe1CWn"
      },
      "source": [
        "# for idx, (label, text) in enumerate(train_dataloader):\n",
        "#   model(item)\n",
        "\n",
        "# Or\n",
        "next(iter(train_dataloader))"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
